{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing Assumptions in model with DoWhy: A simple example\n",
    "This is a quick introduction to how we can test if our assumed graph is correct and the assumptions match with the dataset.\n",
    "We do so by checking the conditional independences in the graph and see if they hold true for the data as well. Currently we are using partial correlation to test continuous data and conditional mutual information to test discrete data.\n",
    "\n",
    "First, let us load all required packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os, sys\n",
    "sys.path.append(os.path.abspath(\"../../../\"))\n",
    "import dowhy\n",
    "from dowhy import CausalModel\n",
    "import dowhy.datasets "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Load dataset\n",
    "The function `dataset_from_random_graph(num_vars, num_samples, prob_edge, random_seed, prob_type_data)` can be used to create a\n",
    "dataset from a randomly generated directed graph. The data can be mixture of discrete, binary and continuous variables. The parameters have the following significance -\n",
    "- num_vars : Number of variables in the dataset\n",
    "- num_samples: Number of samples in the dataset\n",
    "- prob_edge : Probability of an edge between two random nodes in a graph\n",
    "- random_seed: Seed for generating random graph\n",
    "- prob_type_of_data : 3-element tuple containing the probability of data being discrete, binary and continuous respectively.(Should add up to 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = dowhy.datasets.dataset_from_random_graph(num_vars = 10,\n",
    "                                                num_samples = 5000,\n",
    "                                                prob_edge = 0.3,\n",
    "                                                random_seed = 100,\n",
    "                                                prob_type_of_data = (0.333, 0.333, 0.334)) \n",
    "df = data[\"df\"] #Insert dataset here\n",
    "print(data[\"discrete_columns\"], data[\"continuous_columns\"], data[\"binary_columns\"])\n",
    "print(df.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that we are using a pandas dataframe to load the data. At present, DoWhy only supports pandas dataframe as input."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Input causal graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now input a causal graph. You can do that in the GML graph format (recommended), DOT format or the output from daggity -\n",
    "To create the causal graph for your dataset, you can use a tool like [DAGitty](http://dagitty.net/dags.html#) that provides a GUI to construct the graph. You can export the graph string that it generates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_string = \"\"\"graph [\n",
    "  directed 1\n",
    "  node [\n",
    "    id 0\n",
    "    label \"a\"\n",
    "  ]\n",
    "  node [\n",
    "    id 1\n",
    "    label \"b\"\n",
    "  ]\n",
    "  node [\n",
    "    id 2\n",
    "    label \"c\"\n",
    "  ]\n",
    "  node [\n",
    "    id 3\n",
    "    label \"d\"\n",
    "  ]\n",
    "  node [\n",
    "    id 4\n",
    "    label \"e\"\n",
    "  ]\n",
    "  node [\n",
    "    id 5\n",
    "    label \"f\"\n",
    "  ]\n",
    "  node [\n",
    "    id 6\n",
    "    label \"g\"\n",
    "  ]\n",
    "  node [\n",
    "    id 7\n",
    "    label \"h\"\n",
    "  ]\n",
    "  node [\n",
    "    id 8\n",
    "    label \"i\"\n",
    "  ]\n",
    "  node [\n",
    "    id 9\n",
    "    label \"j\"\n",
    "  ]\n",
    "  edge [\n",
    "    source 0\n",
    "    target 1\n",
    "  ]\n",
    "  edge [\n",
    "    source 0\n",
    "    target 8\n",
    "  ]\n",
    "  edge [\n",
    "    source 1\n",
    "    target 2\n",
    "  ]\n",
    "  edge [\n",
    "    source 1\n",
    "    target 5\n",
    "  ]\n",
    "  edge [\n",
    "    source 2\n",
    "    target 3\n",
    "  ]\n",
    "  edge [\n",
    "    source 2\n",
    "    target 4\n",
    "  ]\n",
    "  edge [\n",
    "    source 3\n",
    "    target 4\n",
    "  ]\n",
    "  edge [\n",
    "    source 4\n",
    "    target 5\n",
    "  ]\n",
    "  edge [\n",
    "    source 5\n",
    "    target 6\n",
    "  ]\n",
    "  edge [\n",
    "    source 6\n",
    "    target 7\n",
    "  ]\n",
    "  edge [\n",
    "    source 7\n",
    "    target 8\n",
    "  ]\n",
    "  edge [\n",
    "    source 8\n",
    "    target 9\n",
    "  ]\n",
    "]\n",
    "\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Create Causal Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model=CausalModel(\n",
    "        data = df,\n",
    "        treatment=data[\"treatment_name\"],\n",
    "        outcome=data[\"outcome_name\"],\n",
    "        graph=graph_string\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.view_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "display(Image(filename=\"causal_model.png\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Testing for Conditional Independence\n",
    "We can check if the assumptions of the graph hold true for the data using<br> `model.refute_graph(k, independence_test = {'test_for_continuous': 'partial_correlation', 'test_for_discrete' : 'conditional_mutual_information'})` <br>\n",
    "We are testing X ⫫ Y | Z where X and Y are singular sets and Z can have k number of variables. k value is 1 by default unless input. <br>\n",
    "Currently we are using following settings - \n",
    "- \"partial_correlation\"  for continuous data<br>\n",
    "- \"conditional_mutual_information\" for discrete data<br>\n",
    "- \"conditional_mutua_information\" when Z is discrete and either of X and Y is continuous\n",
    "- \"partial_correlation\" when Z is continuous/binary and X and Y are either continuous/binary\n",
    "- \"conditional_mutual_information\" when X and Y are discrete and Z has mixed data\n",
    "- other settings are currently not supported"
   ]
  },
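  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The partial correlation test named above can be illustrated without DoWhy. The following is a minimal sketch, not DoWhy's internal implementation: the helper `partial_correlation` and the toy variables are hypothetical, and the approach assumes linear relationships. It regresses Z out of X and Y and correlates the residuals, so a value near zero is consistent with X ⫫ Y | Z.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "def partial_correlation(x, y, z):\n",
    "    # Hypothetical helper: correlate the residuals of x and y\n",
    "    # after a least-squares regression on z (with an intercept).\n",
    "    design = np.column_stack([np.ones(len(x)), z])\n",
    "    res_x = x - design @ np.linalg.lstsq(design, x, rcond=None)[0]\n",
    "    res_y = y - design @ np.linalg.lstsq(design, y, rcond=None)[0]\n",
    "    r, p = stats.pearsonr(res_x, res_y)\n",
    "    return r, p\n",
    "\n",
    "# Toy data: z is a common cause of x and y, so x and y are\n",
    "# marginally correlated but independent given z.\n",
    "rng = np.random.default_rng(0)\n",
    "z = rng.normal(size=2000)\n",
    "x = 2.0 * z + rng.normal(size=2000)\n",
    "y = -1.5 * z + rng.normal(size=2000)\n",
    "\n",
    "r_marginal = np.corrcoef(x, y)[0, 1]\n",
    "r_partial, p_value = partial_correlation(x, y, z)\n",
    "print(f'marginal r = {r_marginal:.3f}, partial r given z = {r_partial:.3f}')\n",
    "```\n",
    "\n",
    "In this toy setup the marginal correlation is strong while the partial correlation given Z should be close to zero. The p-value returned by `pearsonr` on residuals does not adjust its degrees of freedom for the conditioning variables, so treat it as approximate."
   ]
  },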
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "refuter_object = model.refute_graph(k=1, independence_test = {'test_for_continuous': 'partial_correlation', 'test_for_discrete' : 'conditional_mutual_information'}) #Change k parameter to test conditional independence given different number of variables "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(refuter_object)"
   ]
  },
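  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For discrete variables, the conditional mutual information I(X; Y | Z) plays the analogous role: it is zero exactly when X ⫫ Y | Z. The sketch below is a plug-in estimate from empirical frequencies, shown only to illustrate the idea. The function name and the toy columns are hypothetical, and this is not DoWhy's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def conditional_mutual_information(df, x, y, z):\n",
    "    # Plug-in estimate of I(X; Y | Z) in nats from empirical frequencies.\n",
    "    n = len(df)\n",
    "    p_xyz = df.groupby([x, y, z]).size() / n\n",
    "    p_xz = df.groupby([x, z]).size() / n\n",
    "    p_yz = df.groupby([y, z]).size() / n\n",
    "    p_z = df.groupby(z).size() / n\n",
    "    cmi = 0.0\n",
    "    for (xv, yv, zv), p in p_xyz.items():\n",
    "        ratio = p * p_z.loc[zv] / (p_xz.loc[(xv, zv)] * p_yz.loc[(yv, zv)])\n",
    "        cmi += p * np.log(ratio)\n",
    "    return cmi\n",
    "\n",
    "# Toy binary data: x and y are noisy copies of z (independent given z),\n",
    "# while w is a noisy copy of x (dependent on x even given z).\n",
    "rng = np.random.default_rng(0)\n",
    "n = 5000\n",
    "z = rng.integers(0, 2, size=n)\n",
    "x = np.where(rng.random(n) < 0.2, 1 - z, z)\n",
    "y = np.where(rng.random(n) < 0.2, 1 - z, z)\n",
    "w = np.where(rng.random(n) < 0.1, 1 - x, x)\n",
    "toy = pd.DataFrame({'x': x, 'y': y, 'z': z, 'w': w})\n",
    "\n",
    "cmi_indep = conditional_mutual_information(toy, 'x', 'y', 'z')\n",
    "cmi_dep = conditional_mutual_information(toy, 'x', 'w', 'z')\n",
    "print(f'I(x; y | z) = {cmi_indep:.4f}, I(x; w | z) = {cmi_dep:.4f}')\n",
    "```\n",
    "\n",
    "The first estimate should be close to zero and the second clearly positive. A plug-in estimate is slightly biased upward on finite samples, so in practice the value is compared against a threshold or a significance test rather than against exactly zero."
   ]
  },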
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing for a set of edges\n",
    "We can also test a set of conditional independences whether they are true or not The input has to be in the form - <br>\n",
    "[( x1, y1, (z1, z2)), <br>\n",
    " ( x2, y2, (z3, z4)),<br>\n",
    " ( x3, y3, (z5,)),<br>\n",
    " ( x4, y4, ())<br>\n",
    " ]<br>\n",
    " ##### The testing data can be a mix of discrete and continuous types as well (Here binary implies discrete only) -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "refuter_object = model.refute_graph(independence_constraints = [('c', 'e' , ('g',)), # c and e - binary, g - continuous\n",
    "                                                                ('f', 'h' , ('b',)), # f and h - discrete, b - continuous\n",
    "                                                                ('e', 'g' , ('h',)), # e - binary, g - continuous, h - discrete\n",
    "                                                                ('c', 'a' , ('b',)), # c and a - discrete, b - continuous\n",
    "                                                                ('d', 'i' , ('c',)), # d and i - continuous, c - binary\n",
    "                                                                ('a', 'j' , ())      # a - discrete, j - continuous\n",
    "                                                               ],\n",
    "                         independence_test = {'test_for_continuous': 'partial_correlation', 'test_for_discrete' : 'conditional_mutual_information'}\n",
    "                        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(refuter_object)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing with a wrong graph input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_string = \"\"\"graph [\n",
    "        directed 1\n",
    "        node [\n",
    "            id 0\n",
    "            label \"a\"\n",
    "        ]\n",
    "        node [\n",
    "            id 1\n",
    "            label \"b\"\n",
    "        ]\n",
    "        node [\n",
    "            id 2\n",
    "            label \"c\"\n",
    "        ]\n",
    "        node [\n",
    "            id 3\n",
    "            label \"d\"\n",
    "        ]\n",
    "        node [\n",
    "            id 4\n",
    "            label \"e\"\n",
    "        ]\n",
    "        node [\n",
    "            id 5\n",
    "            label \"f\"\n",
    "        ]\n",
    "        node [\n",
    "            id 6\n",
    "            label \"g\"\n",
    "        ]\n",
    "        node [\n",
    "            id 7\n",
    "            label \"h\"\n",
    "        ]\n",
    "        node [\n",
    "            id 8\n",
    "            label \"i\"\n",
    "        ]\n",
    "        node [\n",
    "            id 9\n",
    "            label \"j\"\n",
    "        ]\n",
    "        edge [\n",
    "            source 0\n",
    "            target 1\n",
    "        ]\n",
    "        edge [\n",
    "            source 0\n",
    "            target 2\n",
    "        ]\n",
    "        edge [\n",
    "            source 0\n",
    "            target 3\n",
    "        ]\n",
    "        edge [\n",
    "            source 1\n",
    "            target 4\n",
    "        ]\n",
    "        edge [\n",
    "            source 1\n",
    "            target 5\n",
    "        ]\n",
    "        edge [\n",
    "            source 2\n",
    "            target 3\n",
    "        ]\n",
    "        edge [\n",
    "            source 4\n",
    "            target 2\n",
    "        ]\n",
    "        edge [\n",
    "            source 4\n",
    "            target 5\n",
    "        ]\n",
    "        edge [\n",
    "            source 4\n",
    "            target 6\n",
    "        ]\n",
    "        edge [\n",
    "            source 4\n",
    "            target 7\n",
    "        ]\n",
    "        edge [\n",
    "            source 8\n",
    "            target 6\n",
    "        ]\n",
    "        edge\n",
    "        [\n",
    "        source 9\n",
    "        target 0\n",
    "        ]\n",
    "        ]\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CausalModel(\n",
    "            data=df,\n",
    "            treatment=data[\"treatment_name\"],\n",
    "            outcome=data[\"outcome_name\"],\n",
    "            graph=graph_string,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.view_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "display(Image(filename=\"causal_model.png\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "refuter_object = model.refute_graph(k=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that since we input the wrong graph, many conditional independences were not met"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(refuter_object)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
